Fixes for General Additive Models #743
Conversation
Commits:

- …ulations to share one function.
- … Adding a validation on the input value.
- …eeBinaryClassification Loss to take the sigmoid parameter as input; FastTree uses the default and stays the same. GAM uses unity. Refactored GamRegressor and GamClassifier into their own files. Added tests to verify train loss and validation metrics.
- …. Adding XML docs.
- … 2 into the sigmoid parameter; this allows GAMs to output features on the scale of the logit.
- …se of the calibrator.
- …no graph would be defined. As the graph is still accessed, it must be defined to zero in such cases.
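The fix described in the last commit reads clearer as a sketch (hypothetical names; not the actual GAM trainer internals):

```csharp
// Sketch: if boosting never split on a feature, no per-bin effect graph is
// created for it. Downstream code still indexes the graph, so it must be
// defined as all zeros (new double[] is zero-initialized in C#), not left null.
static void EnsureGraphsDefined(double[][] graphs, int binCount)
{
    for (int i = 0; i < graphs.Length; i++)
    {
        if (graphs[i] == null)
            graphs[i] = new double[binCount];
    }
}
```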
```
@@ -8,6 +8,11 @@
     <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
   </PropertyGroup>

+  <ItemGroup>
+    <None Remove="GamClassification.cs" />
```
> None Remove="GamClassification.cs"

What is this `Remove` for?
```
@@ -27,6 +32,8 @@
     <Compile Include="FastTreeRegression.cs" />
     <Compile Include="FastTree.cs" />
     <Compile Include="FastTreeTweedie.cs" />
+    <Compile Include="GamClassification.cs" />
+    <Compile Include="GamRegression.cs" />
+    <Compile Include="GamTrainer.cs" />
```
Actually this is somewhat peculiar. I know you didn't do this, but why is the explicit listing in this project necessary at all? Does anything break if you get rid of this entire `ItemGroup` with all the `<Compile Include>`s?
Thanks for doing this @rogancarr -- should I infer from the lack of baseline changes that no tests were migrated to test GAM? I think we probably ought to do some here -- we have some standard regression problems and whatnot.
```csharp
{
    using (var env = new TlcEnvironment())
    {
        var trainFile = "binary_sine_logistic_10k.tsv";
```
> binary_sine_logistic_10k

Just a reminder that you need to check this file into the repo, or use build actions to download it.
(If you want to check it in, you also need CELA approval.)
Hi @Ivanidzo4ka, I believe this is a synthetic dataset, so CELA approval is probably not necessary, since it does not come from any real external source? Or am I wrong?
This is a synthetic dataset. If we want to add this test to the repo, then I'll add the file.
But the more important question is: do we need this test?

This is really a test on the internal state of the learner that I used for debugging. If we add baseline tests for GAMs, then it seems like we don't need this test, and we may not want it, since it relies on the internal state.
Yes, my best guess is that this line:
New commit addressing comments on the PR:
Next:
… removed TrainingInfo (and therefore the test to validate training loss); removed the post-training scoring update; made the statistics calculator interface internal to FastTree.
New commit addressing offline reviews:
Next:
```
@@ -248,7 +248,7 @@ public void CopyFeatureHistogram(int subfeatureIndex, ref PerBinStats[] hist)
             double sumGTTargets = 0.0;
             double sumGTWeights = eps;
             int gtCount = 0;
-            sumWeights = 2 * eps;
+            sumWeights += 2 * eps;
             double gainShift = leafCalculator.GetLeafSplitGain(totalCount, sumTargets, sumWeights);
```
Nice
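To spell out the one-character fix above (a paraphrase, not the actual FastTree source): `eps` is a small smoothing constant, and the old assignment discarded the accumulated weight mass instead of padding it.

```csharp
// Paraphrased sketch of the fix; the surrounding values are stand-ins.
const double eps = 1e-9;      // smoothing constant (actual value assumed)
double sumWeights = 42.0;     // stand-in for the weight mass accumulated earlier

// Before: assignment threw away the accumulated weights entirely.
// sumWeights = 2 * eps;

// After: the smoothing term pads the accumulated weights instead.
sumWeights += 2 * eps;
```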
New commits

I performed two performance tests and sent the results to reviewers offline:

Test Results:
In the build, SDCA failed a test on Linux Debug. This seems unrelated to the changes in this PR; is it possible to rerun the tests? Edit: I reran the tests.
```csharp
{
    [Argument(ArgumentType.AtMostOnce, HelpText = "Metric for pruning. (For regression, 1: L1, 2: L2; default L2)", ShortName = "pmetric")]
    [TGUI(Description = "Metric for pruning. (For regression, 1: L1, 2: L2; default L2)")]
    public int PruningMetrics = 2;
```
> int PruningMetrics = 2

Wow, really? Can we maybe have an enum here?
This is a holdover from FastTree; it's copied over from `FastTreeRegression`. It's a parameter to the `RegressionTest` class.

If possible, I'd like to file this as a separate issue and fix it in a separate PR.
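For the record, a minimal sketch of what the enum-backed option might look like when that issue is picked up (hypothetical names; the real change would also need command-line back-compat for the existing integer values):

```csharp
// Hypothetical replacement for the magic int above.
public enum PruningMetric
{
    L1 = 1,  // prune on mean absolute error
    L2 = 2,  // prune on mean squared error
}

[Argument(ArgumentType.AtMostOnce, HelpText = "Metric for pruning. (For regression; default L2)", ShortName = "pmetric")]
public PruningMetric PruningMetrics = PruningMetric.L2;
```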
Thank you @rogancarr. I'm requesting a rebuild due to an unrelated SDCA test failure.
Thanks for the review @TomFinley and @Zruty0! I addressed all but one issue and responded to that one in the comments here. Note that WIP is flagged, but this is not WIP; I had a WIP commit that I didn't remove because we will squash-merge this branch.
This PR addresses a number of issues with the General Additive Model (GAM) trainer. In particular, it addresses the issues with the GAM classifier not fitting nor producing a probability, and adds support for validation pruning, summary text, and centering of the feature effects around a mean response (e.g. the intercept).

Additionally, the PR addresses some minor issues in the codebase, like GAMs using copy-and-paste versions of FastTree code, unnecessarily `public` attributes, and unused arguments, and splits the `BinaryClassifier` and `Regressor` into separate files.

The changes are as follows:
1. Added validation pruning:
   a. Switched to using `ScoreTrackers` to keep track of scores during boosting.
   b. Save boost iterations as individual graphs (n_features x n_iterations x n_boosts) for pruning.
2. Changed the GAM classifier to use a small learning rate (FastTree Gradient of Logistic Loss prohibits small learning rates, #741). Updated the `FastTreeBinaryClassification` logistic loss gradient to take the `sigmoid` parameter as input: the GAM binary classifier now uses unity; the FastTree binary classifier uses the same default of `2 * learning rate` (no change). Optionally, we can plumb this through to the FastTree arguments if we show an experimental gain; we can also experiment to see whether the `sigmoid` parameter gives gains for GAMs (i.e., by slowing learning even more). See the sketch after this list.
3. Updated the `GamClassifier` to produce probabilities (General Additive Models (GAM) for Classification learn the logit, not class probability, #738). This meant changing the various interfaces.
4. Split `GAMRegressor` and `GAMClassifier` into their own files. The one file for all GAM trainers had gotten a bit long, and this change is consistent with ML.NET tradition.
5. Updated the `GAMPredictor` to produce training statistics and the feature table (General Additive Model (GAM) has no summary to extract features at runtime, #742).
6. The GAM routines used a copy-and-paste of internal FastTree components. To fight entropy, these were refactored to use the same central calculation, such that it is verified and validated by the FastTree unit and end-to-end tests.
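As referenced in item 2, a rough sketch of the sigmoid-parameterized logistic loss gradient (the functional form and names are assumptions, not the actual FastTree implementation):

```csharp
using System;

static class LogisticLossSketch
{
    // Derivative of the logistic loss ln(1 + exp(-sigmoid * label * score))
    // with respect to the score, for labels in {-1, +1}. GAM passes
    // sigmoid = 1 (unity); FastTree keeps its historical default of
    // 2 * learningRate, so its behavior is unchanged.
    public static double Gradient(double label, double score, double sigmoid)
        => -sigmoid * label / (1.0 + Math.Exp(sigmoid * label * score));
}
```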
Fixes #738
Fixes #739
Fixes #740
Fixes #741
Fixes #742
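Since the classifier now trains with unity sigmoid, its raw score is on the logit scale, and the probability in the #738 fix is the standard logistic transform of that score (a sketch of the relationship only; the shipped code routes this through a calibrator):

```csharp
using System;

static class GamProbabilitySketch
{
    // With sigmoid = 1, the GAM's raw output is a logit, so the class
    // probability is recovered with the standard logistic function.
    public static double FromLogit(double logit)
        => 1.0 / (1.0 + Math.Exp(-logit));
}
```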